import torch
from torch import nn


# MINE estimator T2 for I(Y;X,A|R)
class CMINE(nn.Module):
    """Statistics network T2 used to estimate I(Y;X,A|R) with MINE.

    Two independent MLP encoders embed ``x`` and ``y`` (each flattened to
    28 * 28 values per sample) into 16-dimensional features. The features are
    concatenated with ``a`` and ``r`` and mapped to one scalar per sample by a
    final linear layer.
    """

    # Number of values per flattened x / y sample.
    _INPUT_DIM = 28 * 28
    # Width of the embedding produced by each encoder.
    _FEATURE_DIM = 16

    def __init__(self, num_action: int) -> None:
        """Build the network.

        Args:
            num_action: Width of the ``a`` tensor passed to ``forward``.
        """
        super().__init__()
        self.layer1 = self._make_encoder()
        self.layer2 = self._make_encoder()
        # x features + y features + action vector + one scalar for r.
        self.layer3 = nn.Linear(
            self._FEATURE_DIM + self._FEATURE_DIM + num_action + 1, 1)

    @classmethod
    def _make_encoder(cls) -> nn.Sequential:
        """Return a fresh 784 -> 64 -> 16 ELU encoder."""
        return nn.Sequential(nn.Linear(cls._INPUT_DIM, 64),
                             nn.ELU(),
                             nn.Linear(64, cls._FEATURE_DIM),
                             nn.ELU())

    def forward(self, x: torch.Tensor, a: torch.Tensor, y: torch.Tensor,
                r: torch.Tensor) -> torch.Tensor:
        """Evaluate the statistics network.

        Args:
            x: Tensor whose element count is a multiple of 28 * 28; it is
                flattened to ``(batch, 784)``.
            a: Tensor concatenated along dim 1; must be ``(batch, num_action)``
                to match ``layer3``.
            y: Tensor flattened the same way as ``x``.
            r: Tensor concatenated along dim 1; must be ``(batch, 1)`` to
                match ``layer3``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        # -1 infers the batch size instead of forcing it to 1, so both single
        # samples and mini-batches are accepted; reshape also tolerates
        # non-contiguous inputs.
        x = self.layer1(x.reshape(-1, self._INPUT_DIM))
        y = self.layer2(y.reshape(-1, self._INPUT_DIM))

        z = torch.cat((x, y, a, r), 1)
        return self.layer3(z)


# MINE estimator T3 for I(R;X,A)
class MINE(nn.Module):
    """Statistics network T3 used to estimate I(R;X,A) with MINE.

    An MLP encoder embeds ``x`` (flattened to 28 * 28 values per sample) into
    a 16-dimensional feature. The feature is concatenated with ``a`` and ``r``
    and mapped to one scalar per sample by a final linear layer.
    """

    # Number of values per flattened x sample.
    _INPUT_DIM = 28 * 28
    # Width of the embedding produced by the encoder.
    _FEATURE_DIM = 16

    def __init__(self, num_action: int) -> None:
        """Build the network.

        Args:
            num_action: Width of the ``a`` tensor passed to ``forward``.
        """
        super().__init__()
        self.layer1 = nn.Sequential(nn.Linear(self._INPUT_DIM, 64),
                                    nn.ELU(),
                                    nn.Linear(64, self._FEATURE_DIM),
                                    nn.ELU())
        # x features + action vector + one scalar for r.
        self.layer2 = nn.Linear(self._FEATURE_DIM + num_action + 1, 1)

    def forward(self, x: torch.Tensor, a: torch.Tensor,
                r: torch.Tensor) -> torch.Tensor:
        """Evaluate the statistics network.

        Args:
            x: Tensor whose element count is a multiple of 28 * 28; it is
                flattened to ``(batch, 784)``.
            a: Tensor concatenated along dim 1; must be ``(batch, num_action)``
                to match ``layer2``.
            r: Tensor concatenated along dim 1; must be ``(batch, 1)`` to
                match ``layer2``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        # -1 infers the batch size instead of forcing it to 1, so both single
        # samples and mini-batches are accepted; reshape also tolerates
        # non-contiguous inputs.
        x = self.layer1(x.reshape(-1, self._INPUT_DIM))

        z = torch.cat((x, a, r), 1)
        return self.layer2(z)
